Perception Encoder Integration #2478


Open · wants to merge 12 commits into main

Conversation

@berniebear commented Apr 25, 2025

Add Perception Encoder to timm.

Intro

This PR aims to integrate Perception Encoder (paper, code) from FAIR into timm. We thank you for the support and feedback.

Perception Encoder Performance:

Vision-Language Benchmarks

| Model | Checkpoint | IN-1k | IN-v2 | IN-A | ObjectNet | COCO-T2I | Kinetics-400 | VTT-T2I |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| B/16 224px | vit_pe_core_base_patch16_224 | 78.4 | 71.7 | 62.4 | 71.9 | 50.9 | 65.6 | 47.6 |
| L/14 336px | vit_pe_core_large_patch14_336 | 83.5 | 77.9 | 89.0 | 84.7 | 57.1 | 73.4 | 50.3 |
| G/14 448px | vit_pe_core_gigantic_patch14_448 | 85.4 | 80.2 | 92.6 | 88.2 | 58.1 | 76.9 | 51.2 |

Multimodal LLM Benchmarks

| Encoder | Checkpoint | DocVQA | InfoQA | TextVQA | MVBench | PerceptionTest | EgoSchema |
| --- | --- | --- | --- | --- | --- | --- | --- |
| L/14 448px | vit_pe_lang_large_patch14_448 | 81.9 | 46.4 | 73.0 | 52.3 | 54.7 | 59.8 |
| G/14 448px | vit_pe_lang_gigantic_patch14_448 | 84.4 | 48.3 | 75.2 | 52.4 | 56.0 | 62.0 |

Vision-centric Benchmarks

| Encoder | Checkpoint | ADE20k (linear probe, 448px, w/o TTA) | LVIS (Mask R-CNN 1024px, box / mask mAP) | COCO (DETA 1824px, box mAP) |
| --- | --- | --- | --- | --- |
| G/14 448px | vit_pe_spatial_gigantic_patch14_448 | 49.3 | 54.2 / 49.3 | 66.0 |

Proposed integration and changes:

  1. Add the PE models in pe.py under timm/models/ (see the registration sketch below).
  2. Load the pe module in timm/models/__init__.py.
  3. Process the PE checkpoints on the HF hub (e.g. facebook/vit_pe_core_base_patch16_224_timm) into the safetensors format so they are loadable in timm (via push_to_hub, suggested by NielsRogge).
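
A minimal sketch of what items 1 and 2 look like on the registry side, so that timm.create_model() can resolve the new names. The stub class and hyper-parameters below are illustrative assumptions, not the actual contents of pe.py:

import torch.nn as nn
from timm.models import register_model


class PE(nn.Module):
    # Stand-in for the PE vision transformer that pe.py actually defines.
    def __init__(self, embed_dim: int = 768, num_classes: int = 0):
        super().__init__()
        self.head = nn.Linear(embed_dim, num_classes) if num_classes else nn.Identity()

    def forward(self, x):
        return self.head(x)


@register_model
def vit_pe_core_base_patch16_224(pretrained: bool = False, **kwargs):
    # In the PR this would go through timm's pretrained-weight machinery
    # (build_model_with_cfg / HF hub download); only the registration is shown here.
    return PE(embed_dim=768, **kwargs)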

Known issues/limitations:

  1. PE's RoPE is not compatible with timm's existing layers; PE's RoPE implementation is kept in pe.py for now (a generic sketch of how rotary embeddings are applied follows this list).
  2. PE's vision transformer (the PE class) is customized to use both an absolute pos_emb and RoPE.
  3. The checkpoints currently live under facebook's HF hub org; they need to be copied to timm's HF hub, after which the explicit pretrained_cfg override below is no longer needed.
  4. Currently ViT only for timm; the text transformer is to be integrated into the open_clip repo later.
  5. For PE inference/fine-tuning only; no support yet for PE pre-training from scratch (e.g. no progressive resolution or MetaCLIP curation).
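
For context on item 1, a generic sketch of how rotary position embeddings are typically applied to attention queries/keys (the rotate-half formulation); this is illustrative only and is not the PE implementation in pe.py:

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim into two halves and rotate each pair: (a, b) -> (-b, a).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(q: torch.Tensor, k: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor):
    # q, k: (batch, heads, tokens, head_dim); sin/cos broadcast over batch and heads.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin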

PE models and their HF hub paths

A. ViT only

  1. vit_pe_core_base_patch16_224: facebook/vit_pe_core_base_patch16_224_timm
  2. vit_pe_core_large_patch14_336: facebook/vit_pe_core_large_patch14_336_timm
  3. vit_pe_core_gigantic_patch14_448: facebook/vit_pe_core_gigantic_patch14_448_timm
  4. vit_pe_lang_large_patch14_448: facebook/vit_pe_lang_large_patch14_448_timm
  5. vit_pe_lang_gigantic_patch14_448: facebook/vit_pe_lang_gigantic_patch14_448_timm
  6. vit_pe_spatial_gigantic_patch14_448: facebook/vit_pe_spatial_gigantic_patch14_448_timm

B. CLIP (ViT + Text transformer. For future open_clip integration only)

  1. pe_core_base_patch16_224: facebook/pe_core_base_patch16_224_timm
  2. pe_core_large_patch14_336: facebook/pe_core_large_patch14_336_timm
  3. pe_core_gigantic_patch14_448: facebook/pe_core_gigantic_patch14_448_timm

Test plan (parity):

import torch
import os, sys
from PIL import Image
import timm

## timm model
model_timm = timm.create_model('vit_pe_core_large_patch14_336', pretrained=True, pretrained_cfg = {'hf_hub_id':'facebook/vit_pe_core_large_patch14_336_timm'})
model_timm = model_timm.cuda()

import core.vision_encoder.pe as pe
import core.vision_encoder.transforms as transforms

## original pe model
model_pe = pe.VisionTransformer.from_config("PE-Core-L14-336", pretrained=True)  # Downloads from HF
model_pe = model_pe.cuda()

preprocess = transforms.get_image_transform(model_pe.image_size)
image = preprocess(Image.open("./apps/pe/docs/assets/cat.png")).unsqueeze(0).cuda()

feat_pe = model_pe(image).detach().cpu().numpy()
feat_timm = model_timm(image).detach().cpu().numpy()
print('feat_pe', feat_pe) # [[ 0.8944705   0.32723966 -0.83092093 ... -0.4582289  -0.76679176 -0.29771212]] 
print('feat_pe.shape', feat_pe.shape) # (1, 1024)
print('feat_timm', feat_timm) # [[ 0.8944705   0.32723966 -0.83092093 ... -0.4582289  -0.76679176 -0.29771212]] 
print('feat_timm.shape', feat_timm.shape) # (1, 1024)
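
A small numerical assertion one might append to the parity script above (the tolerances are arbitrary choices):

import numpy as np

# Fails loudly if the timm port drifts from the reference PE features.
np.testing.assert_allclose(feat_timm, feat_pe, rtol=1e-4, atol=1e-5)
print('parity check passed')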

All the models supported and tested:

model_timm = timm.create_model('vit_pe_core_base_patch16_224', pretrained=True, pretrained_cfg = {'hf_hub_id':'facebook/vit_pe_core_base_patch16_224_timm'})
model_timm = timm.create_model('vit_pe_core_large_patch14_336', pretrained=True, pretrained_cfg = {'hf_hub_id':'facebook/vit_pe_core_large_patch14_336_timm'})
model_timm = timm.create_model('vit_pe_core_gigantic_patch14_448', pretrained=True, pretrained_cfg = {'hf_hub_id':'facebook/vit_pe_core_gigantic_patch14_448_timm'})
model_timm = timm.create_model('vit_pe_lang_gigantic_patch14_448', pretrained=True, pretrained_cfg = {'hf_hub_id':'facebook/vit_pe_lang_gigantic_patch14_448_timm'})
model_timm = timm.create_model('vit_pe_lang_large_patch14_448', pretrained=True, pretrained_cfg = {'hf_hub_id':'facebook/vit_pe_lang_large_patch14_448_timm'})
model_timm = timm.create_model('vit_pe_spatial_gigantic_patch14_448', pretrained=True, pretrained_cfg = {'hf_hub_id':'facebook/vit_pe_spatial_gigantic_patch14_448_timm'})

Note:

  1. timm models starting with the vit prefix contain only ViT weights (e.g. vit_pe_core_large_patch14_336).
  2. The PE CLIP checkpoints are placeholders for open_clip integration after the timm integration (e.g. pe_core_gigantic_patch14_448).

Thanks for all the support and feedback for this timm integration!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

elif freqs_for == "lang":
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
elif freqs_for == "pixel":
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi


pi here should load from torch

Collaborator

also, prefer to keep torch.pi vs math.pi and not import x.pi as pi ...
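
Applied to the pixel branch quoted above, the suggested change is simply (sketch):

# Same line as above, using torch.pi instead of an imported/aliased pi:
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * torch.pi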

elif freqs_for == "constant":
freqs = torch.ones(num_freqs).float()
self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)

Collaborator

The freqs here is a parameter that isn't in the original model, so there are complaints about it when loading the state dict... I assume the behaviour in the pretrained model still matches the current code? But for the option of having learned_freq, should this be...

        theta *= theta_rescale_factor ** (dim / (dim - 2))
        if freqs_for == "lang":
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            assert False
        if learned_freq:
            self.freqs = nn.Parameter(freqs)
        else:
            self.freqs = nn.Buffer(freqs, persistent=False)


attn_pooler_heads: int = 8,
pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
num_classes: int = 0, # no use for PE
in_chans: int = 3,
Collaborator

I do need to add support for a classifier, either in the PE module or by wrapping everything, otherwise the default behaviour for adapting encoders as classifiers doesn't work so well ... I'll figure out how best to support it.

@berniebear (Author), Apr 30, 2025

Added classifier support (and reset) in the new commit. Current forward pass: [x -> Transformer(x)] -> [pool -> proj -> head (for classification)], split into forward_features and forward_head respectively. Let's discuss more in Slack (hf-fair-pe-collab). Thank you!
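
A rough sketch of the forward split described above; the attribute names, average pooling, and dimensions are assumptions for illustration, not the exact code in the commit:

import torch
import torch.nn as nn


class PEForwardSketch(nn.Module):
    # Illustrative only: mirrors the described x -> transformer -> pool -> proj -> head flow.
    def __init__(self, embed_dim: int = 1024, proj_dim: int = 1024, num_classes: int = 0):
        super().__init__()
        self.blocks = nn.Identity()  # stand-in for the transformer trunk
        self.proj = nn.Linear(embed_dim, proj_dim)
        self.head = nn.Linear(proj_dim, num_classes) if num_classes else nn.Identity()

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        return self.blocks(x)  # token features (B, N, embed_dim)

    def forward_head(self, x: torch.Tensor) -> torch.Tensor:
        x = x.mean(dim=1)  # simple average pool stands in for PE's pooling here
        return self.head(self.proj(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_head(self.forward_features(x))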


self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
Collaborator

3 should be in_chans
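
Sketched, the fix being asked for (kernel/stride values are assumptions, since the quoted snippet is truncated):

self.conv1 = nn.Conv2d(
    in_channels=in_chans,    # use the constructor argument instead of a hard-coded 3
    out_channels=width,
    kernel_size=patch_size,  # assumption: standard ViT patch embedding
    stride=patch_size,
    bias=False,
)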



class Rope2D(Module):
def __init__(self, dim, grid_size, use_cls_token=False):
Collaborator

This module should be marked non-traceable to pass the FX tests, as it looks like the if t.ndim == 3 check will break tracing.

See e.g.

@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
for use of notrace decorator
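
A sketch of how that could look for this module (the import path shown is how current timm versions expose the decorator; treat it as an assumption):

import torch.nn as nn
from timm.models._features_fx import register_notrace_module


@register_notrace_module  # reason: the data-dependent t.ndim check can't be symbolically traced by FX
class Rope2D(nn.Module):
    def __init__(self, dim, grid_size, use_cls_token=False):
        super().__init__()
        # ... existing rope setup unchanged ...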

freq = torch.cat([freq, torch.zeros(1, freq.shape[-1])], dim=0)

self.freq = Parameter(freq[None, ...]) # remark: using Parameter instead of tensor for device consistency

Collaborator

Also a complaint about this parameter; was it originally not a parameter? It doesn't exist in the state dicts.

Author

Changed the rope freq to nn.Buffer(freqs, persistent=False). Thanks for the suggestion.

@rwightman (Collaborator)

@berniebear sorry, silly typo in my comments that wasn't in my working hacks, it's self.register_buffer not nn.Buffer, haha ...
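
So the corrected version of the earlier suggestion would be roughly:

if learned_freq:
    self.freqs = nn.Parameter(freqs)
else:
    # register_buffer keeps device/dtype movement consistent without adding
    # an entry to the state dict (persistent=False), matching the pretrained checkpoints.
    self.register_buffer('freqs', freqs, persistent=False)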
